#include <linux/random.h>
#include <linux/timex.h>

#include "dbcache.h"
#include "kmeans.h"
#include "hc.h"
#include "util_db.h"

/*
 * Absolute difference of a and b.
 * The whole expansion is parenthesized so the macro can be embedded in a
 * larger expression. Each argument is evaluated more than once, so callers
 * must pass side-effect-free expressions.
 */
#define diff(a, b) ((a) < (b) ? ((b) - (a)) : ((a) - (b)))
/* Index of the smaller of two values: 0 if a <= b (ties pick 0), otherwise 1. */
#define MIN_INDEX(a, b) (((a) <= (b)) ? 0 : 1)
/* Upper bound on the number of k-means refinement iterations. */
#define MAX_LOOP_NUM 1000
/* Seeding mode handed to find_initial_cluster(). */
#define RANDOM_SEED 0  // 0 = k-means++ style seeding, 1 = uniform random seeding

/* Add one sample to the sum/count slots of the centroid closest to it. */
static void add_to_nearest_set(unsigned int data, long long *mass_center, int center_num);
/* Pick the starting centroids; init_random == 1 selects uniform random seeding. */
static int find_initial_cluster(unsigned int *data, int data_num, long long *mass_center, int center_num, int init_random);
/* Return a 64-bit value filled by get_random_bytes(). */
static unsigned long long random(void);
/* Sort x[0 .. num-1] in ascending order, in place. */
static void bubble_sort(unsigned int *x, int num);

/*
 * Boot-time timestamps taken at the start and end of kmeans_cluster().
 * NOTE(review): these have external linkage and are overwritten on every
 * call; confirm whether another file reads them before making them static.
 */
struct timespec64 ts_start, ts_end;
/* ts_end - ts_start for the most recent kmeans_cluster() call. */
struct timespec64 ts_delta;

/*
 * kmeans_cluster - run one-dimensional k-means over the non-zero hotness
 * values of @dc->hc.
 *
 * Every non-zero entry of hc->hotness_array is copied into a scratch buffer,
 * N_CLUSTER centroids are seeded by find_initial_cluster(), and the
 * assign/update step is repeated until no centroid changes or MAX_LOOP_NUM
 * iterations have run. The final centroids are written to hc->centers in
 * ascending order. hc->count is set to the number of non-zero samples as soon
 * as at least one has been found.
 *
 * mass_center holds center_num triples:
 *   [i * 3 + 0] current centroid value
 *   [i * 3 + 1] sum of the samples assigned to centroid i in this iteration
 *   [i * 3 + 2] number of samples assigned to centroid i in this iteration
 *
 * Returns 0 on success and -1 if an allocation fails, if there are no
 * non-zero samples, or if seeding fails.
 */
int kmeans_cluster(struct cached_dev *dc) {
    // FUNC_SHOUT();
    int center_num;
    unsigned int *data = NULL;
    long long *mass_center = NULL; /* centroid / sum / count triples */
    int data_num;
    int i, flag, loop_count, j;
    int ret = 0;
    struct hotness_cluster *hc = dc->hc;

    ktime_get_boottime_ts64(&ts_start);

    center_num = N_CLUSTER;
    data = vmalloc(sizeof(unsigned int) * hc->hotness_nr);
    if (!data) {
        printk("%s: data == NULL.\n", __func__);
        goto err;
    }
    mass_center = kmalloc(sizeof(long long) * center_num * 3, GFP_KERNEL);
    if (!mass_center) {
        printk("%s: mass_center == NULL.\n", __func__);
        goto err;
    }

    /* Only non-zero hotness values take part in the clustering. */
    data_num = 0;
    for (i = 0; i < hc->hotness_nr; i++) {
        if (hc->hotness_array[i]) {
            data[data_num++] = hc->hotness_array[i];
        }
    }
    printk("%s: data_num = %d.\n", __func__, data_num);
    if (data_num == 0) {
        goto err;
    }
    hc->count = data_num;

    if (find_initial_cluster(data, data_num, mass_center, center_num, RANDOM_SEED)) {
        printk("%s: find_initial_cluster error.\n", __func__);
        goto err;
    }

    flag = 1;
    loop_count = 0;
    while (flag == 1 && loop_count < MAX_LOOP_NUM)
    {
        flag = 0;
        ++loop_count;

        /* Reset the per-iteration sum and count of every centroid. */
        for (i = 0; i < center_num; ++i)
        {
            mass_center[i * 3 + 1] = 0;
            mass_center[i * 3 + 2] = 0;
        }
        for (j = 0; j < data_num; ++j)
            add_to_nearest_set(data[j], mass_center, center_num);
        for (i = 0; i < center_num; ++i)
        {
            long long mean;

            /* A centroid that attracted no sample keeps its previous value. */
            if (mass_center[i * 3 + 2] == 0)
                continue;
            mean = mass_center[i * 3 + 1] / mass_center[i * 3 + 2];
            if (mass_center[i * 3] != mean)
            {
                flag = 1;
                mass_center[i * 3] = mean;
            }
        }
    }
    for (i = 0; i < center_num; ++i)
        hc->centers[i] = (unsigned int)mass_center[i * 3];
    bubble_sort(hc->centers, center_num);

    /* Reads centers[0] and centers[1]; presumably N_CLUSTER >= 2 - confirm. */
    printk("centers: %u, %u\n", hc->centers[0], hc->centers[1]);

out:
    /*
     * Release the scratch buffers on every path. vfree() and kfree() both
     * accept NULL, so no guard is needed. The previous "if (!ptr) free(ptr)"
     * guards were inverted and leaked both buffers on every call.
     */
    vfree(data);
    kfree(mass_center);

    ktime_get_boottime_ts64(&ts_end);
    ts_delta = timespec64_sub(ts_end, ts_start);
    // printk("%s: time consumed: %lld (ns)\n", __func__, timespec64_to_ns(&ts_delta));

    return ret;
err:
    ret = -1;
    goto out;
}

/*
 * kmeans_type - classify a hotness value by the nearer of the first two
 * centroids stored in dc->hc->centers.
 *
 * Returns 0 when @hotness is at least as close to centers[0] as it is to
 * centers[1] (a tie yields 0), and 1 otherwise.
 */
unsigned int kmeans_type(struct cached_dev *dc, __u32 hotness) {
    unsigned int to_first;
    unsigned int to_second;

    /* Unsigned absolute distance to each of the two centroids. */
    if (hotness < dc->hc->centers[0])
        to_first = dc->hc->centers[0] - hotness;
    else
        to_first = hotness - dc->hc->centers[0];

    if (hotness < dc->hc->centers[1])
        to_second = dc->hc->centers[1] - hotness;
    else
        to_second = hotness - dc->hc->centers[1];

    if (to_first <= to_second)
        return 0;
    return 1;
}

static int find_initial_cluster(unsigned int *data, int data_num, long long *mass_center, int center_num, int init_random)
{
    int i, j, k;
    unsigned int *distance;
    unsigned long long total_distance;
    unsigned long long threshold;
    unsigned long long distance_sum;
    //随机播种s
    if (init_random == 1)
    {
random_seed:
        for (i = 0; i < center_num; ++i)
            mass_center[i * 3] = data[(int)(random() % data_num)];
        return 0;
    }
    // kmeans++播种
    mass_center[0] = data[(int)(random() % data_num)];
    distance = vmalloc(sizeof(unsigned int) * data_num);
    if (!distance) {
        printk("In %s: distance == NULL, data_num = %d.\n", __func__, data_num);
        return -1;
    }
    for (k = 1; k < center_num; ++k)
    {
        total_distance = 0;
        //求每一个元素到当前所有质心的距离
        for (j = 0; j < data_num; ++j)
        {
            distance[j] = 0;
            for (i = 0; i < k; i++)
                distance[j] += diff(mass_center[i * 3], data[j]);
            total_distance += distance[j];
        }
        //距离当前质心越远的元素更有可能被选为质心
        if (total_distance == 0) goto random_seed;
        threshold = random() % total_distance;
        distance_sum = 0;
        for (j = 0; j < data_num; ++j)
        {
            distance_sum += distance[j];
            if (distance_sum >= threshold)
                break;
        }
        //产生了新的质心
        mass_center[k * 3] = data[j];
    }
    vfree(distance);
    return 0;
}

/* Return a 64-bit value whose bytes are filled in by get_random_bytes(). */
static unsigned long long random(void)
{
    unsigned long long bits;

    get_random_bytes(&bits, sizeof(bits));

    return bits;
}

/*
 * add_to_nearest_set - assign one sample to its closest centroid.
 * @data:        the sample value
 * @mass_center: centroid table; slot [i * 3] is centroid i, [i * 3 + 1] its
 *               running sum and [i * 3 + 2] its running count
 * @center_num:  number of centroids in @mass_center
 *
 * Finds the centroid with the smallest absolute distance to @data (the
 * lowest index wins a tie), then adds @data to that centroid's sum and
 * increments its count.
 */
static void add_to_nearest_set(unsigned int data, long long *mass_center, int center_num)
{
    const long long sample = data;
    long long gap;
    unsigned int best;
    int nearest = 0;
    int idx;

    /* Start with centroid 0 as the best candidate. */
    if (mass_center[0] < sample)
        gap = sample - mass_center[0];
    else
        gap = mass_center[0] - sample;
    best = (unsigned int)gap;

    for (idx = 1; idx < center_num; idx++) {
        const long long centroid = mass_center[idx * 3];
        unsigned int candidate;

        if (centroid < sample)
            gap = sample - centroid;
        else
            gap = centroid - sample;
        candidate = (unsigned int)gap;

        /* Strict comparison: an equally close later centroid does not win. */
        if (candidate < best) {
            best = candidate;
            nearest = idx;
        }
    }

    mass_center[nearest * 3 + 1] += data;
    ++mass_center[nearest * 3 + 2];
}

/*
 * bubble_sort - sort an array of unsigned ints in ascending order, in place.
 * @x:   array to sort
 * @num: number of elements in @x; values <= 1 leave the array untouched
 *
 * The swap temporary is unsigned so values above INT_MAX are never converted
 * through a signed type. The sort stops early once a full pass makes no swap.
 */
static void bubble_sort(unsigned int *x, int num)
{
    int pass, k;

    for (pass = 0; pass < num - 1; ++pass) {
        int swapped = 0;

        for (k = 0; k < num - 1 - pass; ++k) {
            if (x[k] > x[k + 1]) {
                unsigned int tmp = x[k];

                x[k] = x[k + 1];
                x[k + 1] = tmp;
                swapped = 1;
            }
        }
        /* A pass without swaps means the array is already sorted. */
        if (!swapped)
            break;
    }
}
